{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "04975d41",
   "metadata": {},
   "source": [
    "# Query Top-N\n",
    "The following example shows how to use the `ampligraph.discovery.query_topn` API to extract the top-N candidates most likely to be true positives."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f7d979c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# Configure the CUDA/TensorFlow environment switches up front; they are set\n",
    "# before the TensorFlow import below (presumably so they apply at import time).\n",
    "os.environ.update({\n",
    "    'CUDA_VISIBLE_DEVICES': '0',\n",
    "    'TF_ENABLE_ONEDNN_OPTS': '0',\n",
    "    'TF_CPP_MIN_LOG_LEVEL': '3',\n",
    "})\n",
    "\n",
    "import tensorflow as tf\n",
    "\n",
    "# Only let ERROR-level messages through TensorFlow's Python logger.\n",
    "tf.get_logger().setLevel('ERROR')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "dafa4c11",
   "metadata": {},
   "outputs": [],
   "source": [
    "import requests\n",
    "from ampligraph.datasets import load_from_csv\n",
    "\n",
    "# Game of Thrones relations dataset\n",
    "url = 'https://ampligraph.s3-eu-west-1.amazonaws.com/datasets/GoT.csv'\n",
    "\n",
    "# Bound the wait so a stalled connection cannot hang the notebook, and fail\n",
    "# loudly on HTTP errors instead of silently saving an error page as the CSV.\n",
    "response = requests.get(url, timeout=60)\n",
    "response.raise_for_status()\n",
    "\n",
    "# The context manager guarantees the file is flushed and closed before it is\n",
    "# read back by load_from_csv below.\n",
    "with open('GoT.csv', 'wb') as csv_file:\n",
    "    csv_file.write(response.content)\n",
    "\n",
    "# Load the downloaded file from the current directory into X.\n",
    "X = load_from_csv('.', 'GoT.csv', sep=',')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "87ed1243",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metal device set to: Apple M1 Pro\n",
      "\n",
      "systemMemory: 32.00 GB\n",
      "maxCacheSize: 10.67 GB\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x2849ac970>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from ampligraph.latent_features import ScoringBasedEmbeddingModel\n",
    "\n",
    "# Instantiate the embedding model with DistMult scoring (eta=5, k=150).\n",
    "model = ScoringBasedEmbeddingModel(eta=5, k=150, scoring_type='DistMult')\n",
    "\n",
    "# Compile with the Adam optimizer and the pairwise loss, then train on X\n",
    "# for 20 epochs in batches of 100 with progress output disabled.\n",
    "model.compile(optimizer='adam', loss='pairwise')\n",
    "model.fit(X, batch_size=100, epochs=20, verbose=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f531fdd4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([['Eddard Stark', 'ALLIED_WITH', 'House Stark of Winterfell'],\n",
       "        ['Eddard Stark', 'ALLIED_WITH', 'House Westerling of the Crag'],\n",
       "        ['Eddard Stark', 'ALLIED_WITH', \"House Baratheon of Storm's End\"],\n",
       "        ['Eddard Stark', 'ALLIED_WITH', 'House Frey of the Crossing'],\n",
       "        ['Eddard Stark', 'ALLIED_WITH', 'House Karstark of Karhold']],\n",
       "       dtype='<U44'),\n",
       " array([2.4362302 , 0.6104515 , 0.5265138 , 0.5220847 , 0.39194268],\n",
       "       dtype=float32))"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from ampligraph.discovery import query_topn\n",
    "\n",
    "# Leave the tail unspecified so candidates completing\n",
    "# (Eddard Stark, ALLIED_WITH, ?) are returned; keep the top 5.\n",
    "# The result is a pair: the candidate triples and their scores.\n",
    "query_topn(\n",
    "    model,\n",
    "    top_n=5,\n",
    "    head='Eddard Stark',\n",
    "    relation='ALLIED_WITH',\n",
    "    tail=None,\n",
    "    ents_to_consider=None,\n",
    "    rels_to_consider=None,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.6 ('base')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "vscode": {
   "interpreter": {
    "hash": "2e69f3670cdad0193847aaa0b77be56c05c951fcbdd384ff882dde0464f4de76"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
